import copy

import time

import numpy as np
import torch
import gym

from gym import spaces, Wrapper, ObservationWrapper
from gym.wrappers import Monitor

from multiworld.core.image_env import ImageEnv
from multiworld.core.wrapper_env import ProxyEnv





def make_dispatch(env_type,*args,**kwargs):
    """Build an environment by delegating to the factory registered for ``env_type``.

    Args:
        env_type: one of "mujoco", "gridworld", "maze", "multiworld".
        *args, **kwargs: forwarded unchanged to the selected factory.

    Returns:
        Whatever the selected ``make_envs_*`` factory returns.

    Raises:
        Exception: "environment error" when ``env_type`` matches no factory.
    """
    # Lambdas defer the factory lookup/call until a match is found.
    builders = (
        ("mujoco", lambda: make_envs_mujoco(*args, **kwargs)),
        ("gridworld", lambda: make_envs_gridworld(*args, **kwargs)),
        ("maze", lambda: make_envs_maze(*args, **kwargs)),
        ("multiworld", lambda: make_envs_multiworld(*args, **kwargs)),
    )
    for candidate, build in builders:
        if env_type == candidate:
            return build()
    raise Exception("environment error")

def make_envs_gridworld(args,save_dir,name="",video=False,stats=None,render_stats=False,one_goal=False,pos=False,random_pos=False,manual_start_pos=None,goal=None,**kwargs):
    """Build a single 30x30 ``HardWall`` gridworld env wrapped for PyTorch.

    Args:
        args: namespace providing ``max_steps`` and ``image``.
        save_dir: directory for Monitor recordings, or None to disable.
        name: sub-directory under ``save_dir`` used by the Monitor.
        video: wrap with ``gym.wrappers.Monitor`` when True (and ``save_dir`` is set).
        stats: forwarded to ``HardWall``; when ``one_goal`` is True a new axis
            is inserted at position 2 first.
        render_stats, pos, random_pos, manual_start_pos, goal: forwarded to ``HardWall``.
        one_goal: see ``stats``.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        The wrapped environment.
    """
    if stats is not None and one_goal:
        stats = np.expand_dims(stats[:, :], axis=2)

    from envs.empty import HardWall
    from gym_minigrid.wrappers import OneFullyFlatObsWrapper, PosWrapper, SimpleActionWrapper

    hardwall_kwargs = dict(
        goal=goal,
        stats=stats,
        render_stats=render_stats,
        max_steps=args.max_steps,
        size=30,
        pos=pos,
        random_pos=random_pos,
        manual_start_pos=manual_start_pos,
    )
    env = HardWall(**hardwall_kwargs)

    # Image observations vs. position observations.
    obs_wrapper = OneFullyFlatObsWrapper if args.image else PosWrapper
    env = obs_wrapper(env)

    env = TimeLimitSpe(SimpleActionWrapper(env), max_episode_steps=args.max_steps)
    # squeeze=True: actions arrive as single-element tensors and are passed on as scalars.
    env = WrapPyTorch(env, "cpu", squeeze=True)

    if save_dir is not None and video:
        env = Monitor(env, save_dir + "/" + name)
    return env



def make_envs_multiworld(args,save_dir,name="",video=False,reset=True):
    """Build a single multiworld Sawyer env wrapped for PyTorch training.

    Args:
        args: namespace providing ``env_name`` (one of "sawyer_reach",
            "sawyer_door", "sawyer_pickup", "sawyer_push"), ``image``,
            ``max_steps`` and ``seed``.
        save_dir: directory for Monitor recordings, or None to disable.
        name: sub-directory under ``save_dir`` used by the Monitor.
        video: wrap with ``gym.wrappers.Monitor`` when True (and ``save_dir`` is set).
        reset: forwarded to ``TimeLimitSpe`` (auto-reset at episode end).

    Returns:
        The wrapped environment, with env, action space and observation space
        seeded from ``args.seed``.

    Raises:
        Exception: "env name error" for an unknown ``args.env_name``.
    """
    from multiworld.envs.mujoco import register_mujoco_envs
    register_mujoco_envs()

    init_camera = None
    if args.env_name == "sawyer_reach":
        from multiworld.envs.mujoco.sawyer_xyz.sawyer_reach import SawyerReachXYZEnv
        envs = SawyerReachXYZEnv(fix_goal=True)
    elif args.env_name == "sawyer_door":
        from multiworld.envs.mujoco.cameras import sawyer_door_env_camera_v0
        envs = gym.make("SawyerDoorHookResetFreeEnv-v1")
        init_camera = sawyer_door_env_camera_v0
    elif args.env_name == "sawyer_pickup":
        from multiworld.envs.mujoco.cameras import sawyer_pick_and_place_camera
        envs = gym.make("SawyerPickupEnvYZEasy-v0")
        init_camera = sawyer_pick_and_place_camera
    elif args.env_name == "sawyer_push":
        from multiworld.envs.mujoco.cameras import sawyer_init_camera_zoomed_in
        envs = gym.make("SawyerPushNIPSEasy-v0")
        init_camera = sawyer_init_camera_zoomed_in
    else:
        raise Exception("env name error")

    if args.image:
        # ImageEnv renders 48x48 images; ImageNotFlat reshapes them to (3, 48, 48).
        envs = ImageNotFlat(ImageEnv(
            envs,
            non_presampled_goal_img_is_garbage=args.env_name in ("sawyer_door", "sawyer_pickup"),
            imsize=48,
            transpose=True,
            normalize=False,
            init_camera=init_camera,
        ))
    envs = TimeLimitSpe(envs, max_episode_steps=args.max_steps, reset=reset)
    envs = WrapPyTorch(envs, "cpu")
    if save_dir is not None and video:
        envs = Monitor(envs, save_dir + "/" + name)

    envs.seed(args.seed)
    # gym's Space.seed() installs a freshly seeded RNG, so the previous extra
    # `action_space.np_random.seed(...)` call was redundant (and fails on gym
    # versions where np_random is a numpy Generator without .seed()).
    envs.action_space.seed(args.seed)
    envs.observation_space.seed(args.seed)

    return envs


def make_envs_mujoco(args,save_dir,name="",video=False,**kwargs):
    """Build a single gym MuJoCo locomotion env wrapped for PyTorch training.

    Args:
        args: namespace providing ``env_name`` (one of "humanoid", "ant",
            "hopper", "walker2d", "halfcheetah"), ``max_steps`` and ``seed``.
        save_dir: directory for Monitor recordings, or None to disable.
        name: sub-directory under ``save_dir`` used by the Monitor.
        video: wrap with ``gym.wrappers.Monitor`` when True (and ``save_dir`` is set).
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        The wrapped environment, with env, action space and observation space
        seeded from ``args.seed``.

    Raises:
        Exception: "env name error" for an unknown ``args.env_name``.
    """
    if args.env_name == "humanoid":
        from gym.envs.mujoco import HumanoidEnv
        envs = HumanoidEnv()
    elif args.env_name == "ant":
        from gym.envs.mujoco import AntEnv
        envs = AntReduce(AntEnv())
    elif args.env_name == "hopper":
        from gym.envs.mujoco import HopperEnv
        envs = HopperEnv()
    elif args.env_name == "walker2d":
        from gym.envs.mujoco import Walker2dEnv
        envs = Walker2dEnv()
    elif args.env_name == "halfcheetah":
        from gym.envs.mujoco import HalfCheetahEnv
        envs = HalfCheetahEnv()
    else:
        raise Exception("env name error")

    envs = TimeLimitSpe(envs, max_episode_steps=args.max_steps)
    envs = WrapPyTorch(envs, "cpu")
    if save_dir is not None and video:
        envs = Monitor(envs, save_dir + "/" + name)

    envs.seed(args.seed)
    # gym's Space.seed() installs a freshly seeded RNG, so the previous extra
    # `action_space.np_random.seed(...)` call was redundant (and fails on gym
    # versions where np_random is a numpy Generator without .seed()).
    envs.action_space.seed(args.seed)
    envs.observation_space.seed(args.seed)

    return envs

def make_envs_maze(args,save_dir,name="",video=False,reset=True):
    """Build a single maze env (by gym id ``args.env_name``) wrapped for PyTorch.

    Args:
        args: namespace providing ``env_name`` (a gym id), ``max_steps`` and ``seed``.
        save_dir: directory for Monitor recordings, or None to disable.
        name: sub-directory under ``save_dir`` used by the Monitor.
        video: wrap with ``gym.wrappers.Monitor`` when True (and ``save_dir`` is set).
        reset: forwarded to ``TimeLimitSpe`` (auto-reset at episode end).

    Returns:
        The wrapped environment, with env, action space and observation space
        seeded from ``args.seed``.
    """
    # Imported only for its side effects (presumably registering the maze env
    # ids with gym — confirm). Note this binds the local name `envs` to the
    # package; it is immediately rebound to the environment below.
    import envs.mujoco_maze.ant_maze_env

    envs = gym.make(args.env_name)
    envs = TimeLimitSpe(envs, max_episode_steps=args.max_steps, reset=reset)
    envs = WrapPyTorch(envs, "cpu")
    if save_dir is not None and video:
        envs = Monitor(envs, save_dir + "/" + name)

    envs.seed(args.seed)
    # gym's Space.seed() installs a freshly seeded RNG, so the previous extra
    # `action_space.np_random.seed(...)` call was redundant (and fails on gym
    # versions where np_random is a numpy Generator without .seed()).
    envs.action_space.seed(args.seed)
    envs.observation_space.seed(args.seed)

    return envs

class WrapPyTorch(Wrapper):
    """Convert the wrapped env's numpy observations and rewards to torch tensors.

    Observations get a leading dimension of size 1. For dict observations only
    the ``"observation"`` entry is converted; other keys pass through untouched.
    Arrays of dtype float32/float64 become ``torch.float``; every other dtype is
    cast to ``torch.uint8`` (NOTE(review): lossy for integer data outside
    0..255 — confirm that observations are always images or small ints).
    Rewards become float tensors with a trailing dimension of size 1, and
    ``info["true_obs"]`` (when present and not None) is converted as well.
    """

    def __init__(self, env, device,squeeze=False):
        """
        Args:
            env: environment to wrap.
            device: torch device the produced tensors are moved to.
            squeeze: if True, ``step`` passes ``actions.item()`` (a Python
                scalar) to the env; otherwise the action tensor is flattened to
                a 1-D numpy array.
        """
        super(WrapPyTorch, self).__init__(env)
        self.device = device
        self.squeeze = squeeze
        self.metadata = {
            'render.modes': ['rgb_array'],
        }

    @staticmethod
    def _torch_dtype(array):
        """Return torch.float for float32/float64 arrays, torch.uint8 otherwise."""
        # The removed alias np.float was the builtin float, i.e. float64.
        if array.dtype == np.float64 or array.dtype == np.float32:
            return torch.float
        return torch.uint8

    def reset(self):
        obs = self.env.reset()
        if isinstance(obs, dict):
            array = obs["observation"]
            dtype = self._torch_dtype(array)
            obs["observation"] = torch.from_numpy(array).to(self.device, dtype=dtype).view((1,) + array.shape)
        else:
            dtype = self._torch_dtype(obs)
            obs = torch.from_numpy(obs).to(self.device, dtype=dtype).unsqueeze(0)
        return obs

    def step(self, actions):
        if self.squeeze:
            actions = actions.item()
        else:
            actions = actions.view(-1).cpu().numpy()
        obs, reward, done, info = self.env.step(actions)
        true_obs = info.get("true_obs")
        if isinstance(obs, dict):
            array = obs["observation"]
            dtype = self._torch_dtype(array)
            obs["observation"] = torch.from_numpy(array).to(self.device, dtype=dtype).view((1,) + array.shape)
            if true_obs is None:
                info["true_obs"] = None
            else:
                # Same layout as the observation, without the leading batch dim.
                info["true_obs"] = torch.from_numpy(true_obs).to(self.device, dtype=dtype).view(obs["observation"].shape).squeeze(0)
        else:
            dtype = self._torch_dtype(obs)
            obs = torch.from_numpy(obs).to(device=self.device, dtype=dtype).unsqueeze(0)
            info["true_obs"] = None if true_obs is None else torch.from_numpy(true_obs).to(self.device, dtype=dtype)
        reward = torch.tensor(reward).float().unsqueeze(-1).to(self.device)
        return obs, reward, done, info

    def seed(self, seed=None):
        # Propagate the wrapped env's return value (gym's Wrapper.seed does too).
        return self.env.seed(seed)

class AntReduce(ObservationWrapper):
    """Observation wrapper (used in this file around gym's AntEnv) declaring a 27-dim space.

    NOTE(review): `observation` below returns its input unchanged, so no
    observation entries are dropped; only an observation_space is replaced.
    Confirm whether slicing observations to their first 27 entries was intended.
    """

    def __init__(self, env):
        super().__init__(env)
        # NOTE(review): the 27-dim Box is assigned to the *wrapped* env, not to
        # self. Depending on the gym version, Wrapper.__init__ may already have
        # copied env.observation_space onto this wrapper, in which case outer
        # wrappers keep seeing the original space — confirm which one callers use.
        env.observation_space = gym.spaces.Box(np.array([-np.inf]*27),np.array([np.inf]*27),dtype=np.float32)

    def observation(self, observation):
        # Identity: observations pass through unchanged.
        return observation


class TimeLimitSpe(Wrapper):
    """Time-limit wrapper that reports why an episode ended and can auto-reset.

    Every ``step`` writes these keys into ``info``:
      * ``"past"``: True when the step or wall-clock budget is exhausted.
      * ``"over"``: True when the wrapped env reported ``done``.
      * ``"true_obs"``: a copy of the terminal observation when an automatic
        reset happened, otherwise None.
      * ``"true_state"``: (dict observations containing ``"state"`` only) the
        terminal state when an automatic reset happened.

    With ``reset=True``, when the episode is over or past its limit the env is
    reset, the returned observation is the first one of the new episode and
    ``done`` is returned as True. With ``reset=False`` the caller must check
    ``info["past"]`` itself; ``done`` is not forced on a time limit.
    """

    def __init__(self, env, max_episode_seconds=None, max_episode_steps=None,reset=True):
        super(TimeLimitSpe, self).__init__(env)
        self._max_episode_seconds = max_episode_seconds
        # -1 disables both the step and the wall-clock limit (see _past_limit).
        self._max_episode_steps = max_episode_steps
        self.need_reset = reset
        self._elapsed_steps = 0
        self._episode_started_at = None

    @property
    def _elapsed_seconds(self):
        """Wall-clock seconds since the last reset()."""
        return time.time() - self._episode_started_at

    def _past_limit(self):
        """Return True if we are past our step or wall-clock limit.

        ``max_episode_steps == -1`` disables both limits.
        """
        if self._max_episode_steps == -1:
            return False

        if self._max_episode_steps is not None and self._max_episode_steps <= self._elapsed_steps:
            return True

        if self._max_episode_seconds is not None and self._max_episode_seconds <= self._elapsed_seconds:
            return True

        return False

    def step(self, action):
        assert self._episode_started_at is not None, "Cannot call env.step() before calling reset()"

        observation, reward, done, info = self.env.step(action)
        self._elapsed_steps += 1
        # Evaluate the limit once: with a seconds budget two separate calls
        # could disagree, yielding done=True with info["past"] == False.
        past_limit = self._past_limit()

        # TRUE_RESET: `done` is handled at the end of the coordinator, which
        # takes all of them into account.
        # TRUE_RESET: the last state counts as 0 for the sublearner. RESET
        # ANYWAY; it only matters for the coordinator.
        info["past"] = past_limit
        info["over"] = bool(done)
        info["true_obs"] = None

        if self.need_reset and (past_limit or done):
            # Keep the terminal observation before reset() replaces it. Copy in
            # both branches so later env mutations cannot alias it.
            if isinstance(observation, dict):
                if "state" in observation:
                    info["true_state"] = observation["state"]
                info["true_obs"] = np.copy(observation["observation"])
            else:
                info["true_obs"] = np.copy(observation)
            observation = self.reset()
            done = True
        return observation, reward, done, info

    def reset(self):
        self._episode_started_at = time.time()
        self._elapsed_steps = 0
        return self.env.reset()


class ImageNotFlat(ProxyEnv):
    """Reshape flat image observations to ``(3, imsize, imsize)`` arrays.

    Handles both plain-array observations and dict observations (where only the
    ``"observation"`` entry is reshaped and its space replaced). ``render()``
    results are cached until the next ``step()`` or ``reset()``.

    NOTE(review): the declared space is float32 in [0, 1], but this wrapper does
    not convert dtype or rescale values — confirm the wrapped env's images match.
    """

    def __init__(self,wrapped_env,imsize=48):
        """
        Args:
            wrapped_env: environment producing flat image observations.
            imsize: side length of the square 3-channel image (default 48).
        """
        super(ImageNotFlat, self).__init__(wrapped_env)
        self.metadata = {
            'render.modes': ['rgb_array'],
        }
        self._image_shape = (3, imsize, imsize)
        image_space = spaces.Box(
            low=0.,
            high=1.,
            shape=self._image_shape,
            dtype='float32'
        )
        if isinstance(self.wrapped_env.observation_space, gym.spaces.Dict):
            dict_observation_space = copy.deepcopy(self.wrapped_env.observation_space.spaces)
            dict_observation_space["observation"] = image_space
            self.observation_space = gym.spaces.Dict(dict_observation_space)
        else:
            self.observation_space = image_space
        self.image = None

    def _unflatten(self, obs):
        """Reshape the image part of ``obs`` to ``self._image_shape``."""
        if isinstance(obs, dict):
            obs["observation"] = obs["observation"].reshape(self._image_shape)
            return obs
        return obs.reshape(self._image_shape)

    def step(self,action):
        obs, reward, done, info = self.wrapped_env.step(action)
        self.image = None  # invalidate the cached render
        return self._unflatten(obs), reward, done, info

    def reset(self):
        obs = self.wrapped_env.reset()
        self.image = None  # invalidate the cached render
        return self._unflatten(obs)

    def render(self,mode="rgb_array"):
        # Render at most once per step; NOTE: the cache ignores `mode`.
        if self.image is None:
            self.image = self.wrapped_env.render(mode=mode)
        return self.image

def sawyer_door_env_camera_v0(camera):
    """Set the viewer camera pose used for the Sawyer door env.

    Mutates ``camera`` in place (``lookat``, ``distance``, ``elevation``,
    ``azimuth``, ``trackbodyid``) and returns None.
    """
    for axis, coord in enumerate((-.2, 0.55, 0.6)):
        camera.lookat[axis] = coord
    camera.distance = 0.25
    camera.elevation = -60
    camera.azimuth = 360
    camera.trackbodyid = -1

def sawyer_pick_and_place_camera(camera):
    """Set the viewer camera pose used for the Sawyer pick-and-place env.

    Mutates ``camera`` in place (``lookat``, ``distance``, ``elevation``,
    ``azimuth``, ``trackbodyid``) and returns None.
    """
    target = (0.0, .67, .1)
    for axis in range(len(target)):
        camera.lookat[axis] = target[axis]
    camera.distance = .7
    camera.elevation = 0
    camera.azimuth = 180
    camera.trackbodyid = 0

def sawyer_init_camera_zoomed_in(camera, cam_dist=0.3, rotation_angle=270):
    """Set a zoomed-in third-person viewer camera pose for Sawyer envs.

    Mutates ``camera`` in place and returns None. The original version first
    assigned ``trackbodyid = 0`` and ``distance = 1.0`` and then overwrote
    both; those dead stores are removed, so the final camera state is unchanged.

    Args:
        camera: object with an indexable ``lookat`` and writable ``distance``,
            ``elevation``, ``azimuth`` and ``trackbodyid`` attributes.
        cam_dist: camera distance (default 0.3).
        rotation_angle: camera azimuth in degrees (default 270).
    """
    # 3rd person view
    for axis, coord in enumerate((0.0, 0.85, 0.2)):
        camera.lookat[axis] = coord
    camera.distance = float(cam_dist)
    camera.elevation = -45.0
    camera.azimuth = float(rotation_angle)
    camera.trackbodyid = -1